#%%
import numpy as np
import pandas as pd
import torch
from torch.autograd import grad
import cvxpy as cp

import matplotlib.pyplot as plt
import time

from Toy_utils import Monitor

#%%
def run(n = 100, P = 3, nIter = 1000, nTx = 1, x_lr0 = 1e-2, reg2 = 1., typeP = "quad", verbose=2):
    """Run the penalty-based bilevel solver on the convex toy problem and log its progress.

    Notation: y = (y1, y2) with y1 = y[:n], y2 = y[n:].
        upper level  F(x, y) = (y1 - 2) @ (x - 1) + ||y2 + 3||^2
        lower level  f(x, y) = 0.5 ||y1||^2 - x @ y1 + sum(y2)
        constraint   g(x, y) = h(x) + sum(y), h(x) = sum(x) (P == 1) or sum(x**3) (P == 3)
    The constraint enters both subproblems through the penalty P_h(g). The known
    solution used by the metrics is x* = 1, y* = (2, -3).

    Each outer iteration k (reg = 1/sqrt(k+1)):
      1. solves the penalized, regularized lower-level problem for z_k (CVXPY/Clarabel),
      2. solves the penalized upper-level problem, including a penalty on the gap
         f(x, y) - f(x, z_k), for y_k (CVXPY/Clarabel),
      3. takes one gradient step on x of the torch objective phi_k.
    The loop stops early if a subproblem raises ``cp.SolverError`` or yields no
    solution value. The history is still saved in that case.

    Args:
        n: dimension of x (y and z have dimension 2n).
        P: exponent selecting h; must be 1 or 3.
        nIter: maximum number of outer iterations.
        nTx: only used in the result file name (the inner x-loop is disabled).
        x_lr0: step size for the x update.
        reg2: weight of the regularizer reg2 * reg / 2 * ||.||^2 in the subproblems.
        typeP: penalty type, "quad" (1/(2 reg) * sum(v^2)) or "abs" (1/reg * |v|).
        verbose: >= 2 prints per-iteration stats; >= 1 shows a convergence plot
            and reports early stops.

    Raises:
        ValueError: if ``P`` or ``typeP`` is not a supported value.
    """
    # Fail fast: otherwise h / P_h would be unbound and the error would surface
    # much later as an UnboundLocalError.
    if P not in (1, 3):
        raise ValueError(f"P must be 1 or 3, got {P!r}")
    if typeP not in ("abs", "quad"):
        raise ValueError(f"typeP must be 'abs' or 'quad', got {typeP!r}")

    # NOTE: "panalty" is misspelled, but it is part of the output file name;
    # kept as-is so result file names stay consistent with earlier runs.
    resultfile = f'D:/Results/ICML24/Toy_Convex/Toy_P={P}_n={n}_BVFSM2_panalty={typeP}_reg2={reg2}_xlr={x_lr0}_Tx={nTx}.csv'

    # Each objective has a torch version (for autograd) and a CVXPY version
    # (suffix 2, used to build the convex subproblems).
    F = lambda x, y : (y[:n] - 2) @ (x - 1) + torch.sum((y[n:] + 3) ** 2)
    F2 = lambda x, y : (y[:n] - 2) @ (x - 1) + cp.sum((y[n:] + 3) ** 2)

    f = lambda x, y : .5 * torch.sum(y[:n] ** 2) - x @ y[:n] + torch.sum(y[n:])
    f2 = lambda x, y : .5 * cp.sum(y[:n] ** 2) - x @ y[:n] + cp.sum(y[n:])

    if P == 1:
        h = lambda x : torch.sum(x)
        h2 = lambda x : cp.sum(x)
    else:  # P == 3
        h = lambda x : torch.sum(x ** 3)
        h2 = lambda x : cp.sum(x ** 3)

    g = lambda x, y : h(x) + torch.sum(y)
    g2 = lambda x, y : h2(x) + cp.sum(y)

    # Known upper-level solution x* = 1 (the matching y* is (2 * ones(n), -3 * ones(n))).
    xopt = torch.ones(n)

    # Project v onto the hyperplane {u : sum(u) + b = 0}.
    proj1 = lambda v, b : v - (torch.sum(v) + b) / len(v)
    # Reference lower-level point for the dy metric: y1 = x + 1, and y2 is the
    # current y2 projected onto sum(y2) = -(sum(x + 1) + h(x)).
    yx   = lambda x, y: torch.cat((x+1, proj1(y[n:], torch.sum(x+1)+h(x))))
    y2yx = lambda x, y: torch.norm( y - yx(x, y) )
    metric_x = lambda x, y : torch.norm(x - xopt) / torch.norm(xopt)
    metric_y = lambda x, y : y2yx(x, y) / torch.norm(yx(x, y))

    # Penalty on a violation v with weight 1/reg; reg is passed explicitly
    # because it changes every iteration.
    if typeP == "abs":
        P_h = lambda v, reg : 1./reg * torch.abs(v)
        P_h2 = lambda v, reg : 1./reg * cp.abs(v)
    else:  # "quad"
        P_h = lambda v, reg : 1./(2 * reg) * torch.sum(v ** 2)
        P_h2 = lambda v, reg : 1./(2 * reg) * cp.sum(v ** 2)

    def solved(problem, var, label, k):
        """Solve `problem` with Clarabel; return True iff `var` received a value."""
        try:
            problem.solve(cp.CLARABEL, verbose=False)
        except cp.SolverError as err:
            if verbose >= 1:
                print(f"{k:3d}-iter: {label}-subproblem failed ({err}); stopping early.")
            return False
        if var.value is None:
            # e.g. infeasible/unbounded status: no primal value to read back.
            if verbose >= 1:
                print(f"{k:3d}-iter: {label}-subproblem has no solution (status={problem.status}); stopping early.")
            return False
        return True

    # initial guess
    xk = torch.zeros(n).requires_grad_(True)
    yk = torch.zeros(2 * n).requires_grad_(True)
    zk = torch.zeros(2 * n).requires_grad_(True)

    monitor = Monitor()

    def record(k, T):
        """Append the current iterate's objective values and error metrics to the monitor."""
        monitor.append({
            "k": k, "time": T,
            "F": F(xk, yk).detach().numpy(), "f": f(xk, yk).detach().numpy(),
            "g": g(xk, yk).detach().numpy(),
            "dx": metric_x(xk, yk).detach().numpy(),
            "dy": metric_y(xk, yk).detach().numpy(),
        })

    T = 0
    record(0, T)  # row k = 0 is the initial point

    for k in range(nIter):
        t0 = time.time()
        # Penalty and regularization weight, shrinking as 1/sqrt(k + 1).
        reg = 1. / (k + 1) ** 0.5
        x_lr = x_lr0

        # Step 1: penalized, regularized lower-level problem -> z_k.
        # The CVXPY problems are rebuilt each iteration since reg enters them as a constant.
        zz = cp.Variable(2 * n)
        xx = cp.Parameter(n)
        fzz = f2(xx, zz) + P_h2(g2(xx, zz), reg) + reg2 * reg / 2 * cp.sum_squares(zz)
        subz = cp.Problem(cp.Minimize(fzz))
        xx.value = xk.detach().numpy()
        if not solved(subz, zz, "z", k):
            break
        zk.data = torch.tensor(zz.value).float()

        # Step 2: upper-level problem with penalties on g and on the gap
        # max(f(x, y) - f(x, z_k), 0) -> y_k.
        yy = cp.Variable(2 * n)
        xx = cp.Parameter(n)
        fxz = cp.Parameter()
        fyy = F2(xx, yy) + P_h2(g2(xx, yy), reg) + P_h2(cp.maximum(f2(xx, yy) - fxz, 0), reg) + reg2 * reg / 2 * cp.sum_squares(yy)
        xx.value = xk.detach().numpy()
        fxz.value = f(xk, zk).detach().numpy()
        suby = cp.Problem(cp.Minimize(fyy))
        if not solved(suby, yy, "y", k):
            break
        yk.data = torch.tensor(yy.value).float()

        # Step 3: one gradient step on x with z_k and y_k held fixed. fz still
        # depends on xk, so the gradient also flows through the gap penalty's fz term.
        z_fixed, y_fixed = zk.detach(), yk.detach()
        fz = f(xk, z_fixed) + P_h(g(xk, z_fixed), reg) + reg2 * reg / 2 * torch.sum(z_fixed ** 2)
        phik = F(xk, y_fixed) + P_h(torch.relu(f(xk, y_fixed) - fz), reg) + P_h(g(xk, y_fixed), reg)
        dx = grad(phik, xk)[0]
        xk.data -= x_lr * dx

        # Only the solver/update work is timed; logging is excluded.
        T += time.time() - t0
        # k + 1 outer iterations are complete (row 0 is the initial point).
        record(k + 1, T)

        if verbose>=2:
            print(f"{k + 1:3d}-iter, F={F(xk, yk):>8.2f}, g={g(xk, yk):>8.2f}, dx={metric_x(xk, yk):>6.2f}, dy = {metric_y(xk, yk):>6.2f}")

    if verbose >= 1:
        plt.plot(monitor.time, monitor.dx, label=r'$|x - x^*| / |x^*|$')
        plt.plot(monitor.time, monitor.dy, label=r'dist(y, y(x))')
        plt.legend()
        plt.show()

    monitor.save_csv(resultfile)


# %%
if __name__ == "__main__":
    # Each entry is (n, P, reg2, number of outer iterations); experiments run in order.
    experiments = [
        (1000, 3, 1.0, 2000),
        (1000, 1, 1.0, 100),
    ]
    for n, P, reg2, n_iter in experiments:
        print(f"Experiment with n = {n}, P = {P}, reg2 = {reg2}")
        run(n = n, P = P, nIter = n_iter, reg2 = reg2, verbose=1)
# %%